{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Document representation from BERT",
      "provenance": [
        {
          "file_id": "1hccaqNncyxDG32f5U1Qncz6ipWiLV0TQ",
          "timestamp": 1614125265907
        },
        {
          "file_id": "1d9KurXhXvrV-jo-x2f7DkZ40qx75YAh_",
          "timestamp": 1604694308174
        }
      ],
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//corp/legal/patents/colab:dst_colab_notebook",
        "kind": "shared"
      },
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ED6tBdZtOjlU"
      },
      "source": [
        "# Document representation from BERT"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CqNm7ioGOgSm"
      },
      "source": [
        "Copyright 2021 Google Inc.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
        "\n",
        "http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c1vLcDJINTGg"
      },
      "source": [
        "import collections\n",
        "import math\n",
        "import random\n",
        "import sys\n",
        "import time\n",
        "from typing import Dict, List, Tuple\n",
        "from sklearn.metrics import pairwise\n",
        "# Use Tensorflow 2.0\n",
        "import tensorflow as tf\n",
        "import numpy as np"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vfSIZaeaPHpZ",
        "colab": {
          "height": 53
        },
        "executionInfo": {
          "status": "ok",
          "timestamp": 1614125346371,
          "user_tz": 300,
          "elapsed": 155,
          "user": {
            "displayName": "Rob Srebrovic",
            "photoUrl": "",
            "userId": "06004353344935214283"
          }
        },
        "outputId": "c0bca557-2962-4f3b-a8f9-71be6d820897"
      },
      "source": [
        "# Set BigQuery application credentials\n",
        "from google.cloud import bigquery\n",
        "import os\n",
        "# Placeholder: replace with the path to your credentials JSON file.\n",
        "os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"path/to/file.json\"\n",
        "\n",
        "# Placeholder: replace with the ID of the project the BigQuery client runs under.\n",
        "project_id = \"your_bq_project_id\"\n",
        "# Client used later in the notebook to query the patents data.\n",
        "bq_client = bigquery.Client(project=project_id)"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'# Set BigQuery application credentials\\nfrom google.cloud import bigquery\\nimport os\\nos.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"path/to/file.json\"\\n\\nproject_id = \"your_bq_project_id\"\\nbq_client = bigquery.Client(project=project_id)'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 2
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7BojUHDYrESY"
      },
      "source": [
        "# Clone the BERT repo; the clone is skipped when the bert_repo directory already exists.\n",
        "!test -d bert_repo || git clone https://github.com/google-research/bert bert_repo\n",
        "# Make the repo's modules importable, without adding a duplicate entry when\n",
        "# this cell is re-run.\n",
        "if 'bert_repo' not in sys.path:\n",
        "  sys.path.append('bert_repo')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QeoX7LfgPLGP"
      },
      "source": [
        "The BERT repo targets TensorFlow 1, and a few of the functions it uses have been moved, changed, or renamed in TensorFlow 2. For the BERT tokenizer to work with TensorFlow 2, one line in the repo that was just cloned needs to be modified. Line 125 in the BERT `tokenization.py` file (`bert_repo/tokenization.py`) must be changed as follows:\n",
        "\n",
        "From => `with tf.gfile.GFile(vocab_file, \"r\") as reader:`\n",
        "\n",
        "To => `with tf.io.gfile.GFile(vocab_file, \"r\") as reader:`\n",
        "\n",
        "Once that is complete and the file is saved, the tokenization library can be imported."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HsSJXKPDPLXn"
      },
      "source": [
        "import tokenization"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JBqRRfigQxxK"
      },
      "source": [
        "# Load BERT"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kp2fx508lWBG"
      },
      "source": [
        "# Length every tokenized input is truncated/padded to, including the\n",
        "# [CLS] and [SEP] tokens.\n",
        "MAX_SEQ_LENGTH = 512\n",
        "# Placeholder: directory of the exported BERT SavedModel.\n",
        "MODEL_DIR = 'path/to/model'\n",
        "# Placeholder: vocabulary file handed to the BERT tokenizer.\n",
        "VOCAB = 'path/to/vocab'\n",
        "\n",
        "# Tokenizer from the cloned BERT repo, configured with do_lower_case=True.\n",
        "tokenizer = tokenization.FullTokenizer(VOCAB, do_lower_case=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sNf96pSxxXg2"
      },
      "source": [
        "# Keep the loaded SavedModel under its own name instead of rebinding `model`\n",
        "# from the loaded object to one of its signatures: each name then refers to\n",
        "# one kind of object, and the loaded object stays referenced while its\n",
        "# signature is in use.\n",
        "loaded_model = tf.compat.v2.saved_model.load(export_dir=MODEL_DIR, tags=['serve'])\n",
        "# Callable used by the rest of the notebook: it is invoked with the input\n",
        "# tensors as keyword arguments and its result is indexed by output name.\n",
        "model = loaded_model.signatures['serving_default']"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-BWnaHqoT7db"
      },
      "source": [
        "# Mean pooling layer used to average the token embeddings of a document\n",
        "# into a single vector.\n",
        "pooling = tf.keras.layers.GlobalAveragePooling1D()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rtzZg5LESCxF"
      },
      "source": [
        "# Get a couple of Patents\n",
        "\n",
        "Here we do a simple query from the BigQuery patents data to collect the claims for a sample set of patents."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "u3iTTJQ5SFba"
      },
      "source": [
        "# Put your publications here.\n",
        "test_pubs = (\n",
        "    'US-8000000-B2', 'US-2007186831-A1', 'US-2009030261-A1', 'US-10722718-B2'\n",
        ")\n",
        "\n",
        "# Body of the JavaScript UDF that breaks a claims text into individual claims.\n",
        "# It splits on a period, whitespace, a claim number, optional whitespace and a\n",
        "# period; nothing is returned when the text contains no such separator.\n",
        "js = r\"\"\"\n",
        "  // Regex to find the separations of the claims data\n",
        "  var pattern = new RegExp(/[.][\\\\s]+[0-9]+[\\\\s]*[.]/, 'g');\n",
        "  if (pattern.test(text)) {\n",
        "    return text.split(pattern);\n",
        "  }\n",
        "\"\"\"\n",
        "\n",
        "# The publication numbers are passed as a query parameter rather than being\n",
        "# formatted into the SQL text: a one-element Python tuple would otherwise be\n",
        "# rendered as ('x',), and the values no longer need any quoting or escaping.\n",
        "query = r'''\n",
        "  #standardSQL\n",
        "  CREATE TEMPORARY FUNCTION breakout_claims(text STRING) RETURNS ARRAY<STRING> \n",
        "  LANGUAGE js AS \"\"\"\n",
        "  {}\n",
        "  \"\"\"; \n",
        "\n",
        "  SELECT \n",
        "    pubs.publication_number, \n",
        "    title.text as title, \n",
        "    breakout_claims(claims.text) as claims\n",
        "  FROM `patents-public-data.patents.publications` as pubs,\n",
        "    UNNEST(claims_localized) as claims,\n",
        "    UNNEST(title_localized) as title\n",
        "  WHERE\n",
        "    publication_number IN UNNEST(@publication_numbers)\n",
        "'''.format(js)\n",
        "\n",
        "job_config = bigquery.QueryJobConfig(\n",
        "    query_parameters=[\n",
        "        bigquery.ArrayQueryParameter(\n",
        "            'publication_numbers', 'STRING', list(test_pubs)),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# One row per (publication, localized claims, localized title) combination.\n",
        "df = bq_client.query(query, job_config=job_config).to_dataframe()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "height": 241
        },
        "id": "ORcVOefPsT0U",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1614011849900,
          "user_tz": 300,
          "elapsed": 309,
          "user": {
            "displayName": "Jay Yonamine",
            "photoUrl": "",
            "userId": "01949405773282057831"
          }
        },
        "outputId": "5299f3c1-b64e-4cbd-9206-273d1fb1d300"
      },
      "source": [
        "# Preview the first rows returned by the query.\n",
        "df.head()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>publication_number</th>\n",
              "      <th>title</th>\n",
              "      <th>claims</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>US-2009030261-A1</td>\n",
              "      <td>Drug delivery system</td>\n",
              "      <td>[1 . A drug delivery system comprising:\\n a ca...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>US-2007186831-A1</td>\n",
              "      <td>Sewing machine</td>\n",
              "      <td>[1 . A sewing machine comprising:\\n a needle b...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>US-8000000-B2</td>\n",
              "      <td>Visual prosthesis</td>\n",
              "      <td>[1. A visual prosthesis apparatus comprising:\\...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>US-10722718-B2</td>\n",
              "      <td>Systems and methods for treatment of dry eye</td>\n",
              "      <td>[What is claimed is: \\n     \\n       1. A meth...</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "  publication_number  ...                                             claims\n",
              "0   US-2009030261-A1  ...  [1 . A drug delivery system comprising:\\n a ca...\n",
              "1   US-2007186831-A1  ...  [1 . A sewing machine comprising:\\n a needle b...\n",
              "2      US-8000000-B2  ...  [1. A visual prosthesis apparatus comprising:\\...\n",
              "3     US-10722718-B2  ...  [What is claimed is: \\n     \\n       1. A meth...\n",
              "\n",
              "[4 rows x 3 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NeFzKlMw1DQd"
      },
      "source": [
        "def get_bert_token_input(texts):\n",
        "  \"\"\"Converts texts into the padded input tensors fed to the BERT model.\n",
        "\n",
        "  Each text is tokenized, truncated so that it fits in MAX_SEQ_LENGTH together\n",
        "  with the surrounding [CLS] and [SEP] tokens, and right-padded with zeros.\n",
        "\n",
        "  Args:\n",
        "    texts: Iterable of strings to encode.\n",
        "\n",
        "  Returns:\n",
        "    Dict of int64 tensors keyed by 'segment_ids' (all zeros), 'input_mask'\n",
        "    (1 for real tokens, 0 for padding), 'input_ids' (token ids) and\n",
        "    'mlm_positions' (always empty).\n",
        "  \"\"\"\n",
        "  # Two positions are reserved for the [CLS] and [SEP] markers.\n",
        "  max_tokens = MAX_SEQ_LENGTH - 2\n",
        "  id_rows, mask_rows, segment_rows = [], [], []\n",
        "\n",
        "  for text in texts:\n",
        "    pieces = ['[CLS]'] + tokenizer.tokenize(text)[:max_tokens] + ['[SEP]']\n",
        "    ids = tokenizer.convert_tokens_to_ids(pieces)\n",
        "    n_pad = MAX_SEQ_LENGTH - len(ids)\n",
        "    id_rows.append(ids + [0] * n_pad)\n",
        "    mask_rows.append([1] * len(ids) + [0] * n_pad)\n",
        "    segment_rows.append([0] * MAX_SEQ_LENGTH)\n",
        "\n",
        "  def as_int64(rows):\n",
        "    return tf.convert_to_tensor(rows, dtype=tf.int64)\n",
        "\n",
        "  return {\n",
        "      'segment_ids': as_int64(segment_rows),\n",
        "      'input_mask': as_int64(mask_rows),\n",
        "      'input_ids': as_int64(id_rows),\n",
        "      'mlm_positions': as_int64([])\n",
        "  }"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MlrVU10IOlSZ"
      },
      "source": [
        "# One embedding per row of df: the mean over every token position of every\n",
        "# claim in that row.\n",
        "docs_embeddings = []\n",
        "for _, row in df.iterrows():\n",
        "  response = model(**get_bert_token_input(row['claims']))\n",
        "  # Merge all claims into a single sequence of 1024-dim vectors before averaging.\n",
        "  # NOTE(review): pooling is called without a mask, so padded positions appear\n",
        "  # to be averaged in as well - confirm this is intended.\n",
        "  all_tokens = tf.reshape(response['encoder_layer'], shape=[1, -1, 1024])\n",
        "  doc_vector = pooling(all_tokens).numpy()[0]\n",
        "  docs_embeddings.append(doc_vector)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DhF2-w2yU52U",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1614012215102,
          "user_tz": 300,
          "elapsed": 240,
          "user": {
            "displayName": "Jay Yonamine",
            "photoUrl": "",
            "userId": "01949405773282057831"
          }
        },
        "outputId": "c6148de6-f1c2-40c3-d75d-90cc0f4e0469"
      },
      "source": [
        "# Pairwise cosine similarity between the document embeddings.\n",
        "pairwise.cosine_similarity(docs_embeddings)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[0.9999988 , 0.68387157, 0.83200616, 0.86913264],\n",
              "       [0.68387157, 1.0000013 , 0.7299322 , 0.73105675],\n",
              "       [0.83200616, 0.7299322 , 0.99999964, 0.9027555 ],\n",
              "       [0.86913264, 0.73105675, 0.9027555 , 0.9999996 ]], dtype=float32)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TFWxL-IGU9-6",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1614012321633,
          "user_tz": 300,
          "elapsed": 227,
          "user": {
            "displayName": "Jay Yonamine",
            "photoUrl": "",
            "userId": "01949405773282057831"
          }
        },
        "outputId": "9fffcf1d-0c2c-4d84-eb8e-847d6054f125"
      },
      "source": [
        "# Shape of a single document embedding.\n",
        "docs_embeddings[0].shape"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(1024,)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 23
        }
      ]
    }
  ]
}
